import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torchvision.models import resnet101
 
# preprocessing dependencies
from advertorch.utils import predict_from_logits
from advertorch.utils import NormalizeByChannelMeanStd
from advertorch_examples.utils import ImageNetClassNameLookup
from advertorch_examples.utils import get_panda_image
from advertorch_examples.utils import bhwc2bchw
from advertorch_examples.utils import bchw2bhwc


# load the attack class
from advertorch.attacks import LinfPGDAttack

'''
Mainly based on tutorial_attack_imagenet.ipynb.
Generates an adversarial example for a single image.
'''
def main():
    """Craft one L-inf PGD adversarial example and plot it beside the original.

    Steps:
        1. Build the classifier: a pretrained ResNet-101 preceded by a
           per-channel normalization layer, so the attack is applied to the
           image tensor before normalization.
        2. Load the example panda image and its true label (index 388).
        3. Run ``LinfPGDAttack`` and draw a three-panel figure: the original
           image, the amplified perturbation and the adversarial image.

    Side effects:
        Saves the figure as ``panda_advimg.jpg`` in the current working
        directory and then calls ``plt.show()``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # 1. Load the model.
    # Preprocessing: per-channel mean/std normalization.
    normalize = NormalizeByChannelMeanStd(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    # Pretrained model: network structure plus trained parameters.
    model = resnet101(pretrained=True)
    model.eval()
    # Normalize first, then feed the result to the network.
    model = nn.Sequential(normalize, model)
    model = model.to(device)

    # 2. Load the image and the index of its true label.
    # NOTE(review): the original comment marks this call as the place to
    # substitute a custom image-loading function.
    np_img = get_panda_image()
    # Both the image and the label must be tensors; [None, ...] adds the
    # leading batch dimension.
    img = torch.tensor(bhwc2bchw(np_img))[None, :, :, :].float().to(device)
    label = torch.tensor([388, ]).long().to(device)  # true label
    imagenet_label2classname = ImageNetClassNameLookup()

    # Two helper functions.
    def tensor2npimg(tensor):
        """Convert the first image of a tensor batch to a numpy image.

        The tensor is detached (``.numpy()`` raises on tensors that require
        grad) and moved to the CPU, then passed through ``bchw2bhwc``.
        """
        return bchw2bhwc(tensor[0].detach().cpu().numpy())

    def _show_images(img, advimg, enhance=127):
        """Plot the original image, the perturbation and the adversarial image.

        Args:
            img: original image tensor that was passed to the attack.
            advimg: adversarial image tensor returned by the attack.
            enhance: factor the perturbation is multiplied by before it is
                displayed, so that a small perturbation becomes visible.
        """
        # Adversarial image.
        np_advimg = tensor2npimg(advimg)
        # Perturbation.
        np_perturb = tensor2npimg(advimg - img)

        # Predicted class names; gradients are not needed for plain inference.
        with torch.no_grad():
            pred = imagenet_label2classname(
                predict_from_logits(model(img)))
            advpred = imagenet_label2classname(
                predict_from_logits(model(advimg)))

        plt.figure(figsize=(10, 5))
        plt.subplot(1, 3, 1)
        plt.imshow(np_img)                      # original image

        plt.axis("off")
        plt.title("original image\n prediction: {}".format(pred))
        plt.subplot(1, 3, 2)
        # Amplify the perturbation and shift it by 0.5 so that a zero
        # perturbation is drawn as mid-grey.
        plt.imshow(np_perturb * enhance + 0.5)

        plt.axis("off")
        plt.title("The perturbation")
        plt.subplot(1, 3, 3)
        plt.imshow(np_advimg)                   # adversarial image
        plt.axis("off")
        plt.title("Perturbed image\n prediction: {}".format(advpred))

        plt.savefig('panda_advimg.jpg')        # save the figure
        plt.show()

    # 3. Generate the adversarial image.
    # eps is 1/255; the 40 steps of size eps * 2 / 40 add up to 2 * eps.
    adversary = LinfPGDAttack(
        model,
        eps=1./255, eps_iter=1./255*2/40, nb_iter=40,
        rand_init=False, targeted=False)
    advimg = adversary.perturb(img, label)
    _show_images(img, advimg)



# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
